{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "8ag5ANQPJ_j9"
   },
   "source": [
    "# Finetune 🤗 Transformers Models with PyTorch Lightning ⚡\n",
    "\n",
    "This notebook will use HuggingFace's `datasets` library to get data, which will be wrapped in a `LightningDataModule`. Then, we write a class to perform text classification on any dataset from the[ GLUE Benchmark](https://gluebenchmark.com/). (We just show CoLA and MRPC due to constraint on compute/disk)\n",
    "\n",
    "[HuggingFace's NLP Viewer](https://huggingface.co/nlp/viewer/?dataset=glue&config=cola) can help you get a feel for the two datasets we will use and what tasks they are solving for.\n",
    "\n",
    "---\n",
    "  - Give us a ⭐ [on Github](https://www.github.com/PytorchLightning/pytorch-lightning/)\n",
    "  - Check out [the documentation](https://pytorch-lightning.readthedocs.io/en/latest/)\n",
    "  - Ask a question on [the forum](https://forums.pytorchlightning.ai/)\n",
    "  - Join us [on Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)\n",
    "\n",
    "  - [HuggingFace datasets](https://github.com/huggingface/datasets)\n",
    "  - [HuggingFace transformers](https://github.com/huggingface/transformers)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "fqlsVTj7McZ3"
   },
   "source": [
    "### Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "OIhHrRL-MnKK"
   },
   "outputs": [],
   "source": [
    "!pip install pytorch-lightning datasets transformers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "6yuQT_ZQMpCg"
   },
   "outputs": [],
   "source": [
    "from argparse import ArgumentParser\n",
    "from datetime import datetime\n",
    "from typing import Optional\n",
    "\n",
    "import datasets\n",
    "import numpy as np\n",
    "import pytorch_lightning as pl\n",
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "from transformers import (\n",
    "    AdamW,\n",
    "    AutoModelForSequenceClassification,\n",
    "    AutoConfig,\n",
    "    AutoTokenizer,\n",
    "    get_linear_schedule_with_warmup,\n",
    "    glue_compute_metrics\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "9ORJfiuiNZ_N"
   },
   "source": [
    "## GLUE DataModule"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "jW9xQhZxMz1G"
   },
   "outputs": [],
   "source": [
    "class GLUEDataModule(pl.LightningDataModule):\n",
    "\n",
    "    task_text_field_map = {\n",
    "        'cola': ['sentence'],\n",
    "        'sst2': ['sentence'],\n",
    "        'mrpc': ['sentence1', 'sentence2'],\n",
    "        'qqp': ['question1', 'question2'],\n",
    "        'stsb': ['sentence1', 'sentence2'],\n",
    "        'mnli': ['premise', 'hypothesis'],\n",
    "        'qnli': ['question', 'sentence'],\n",
    "        'rte': ['sentence1', 'sentence2'],\n",
    "        'wnli': ['sentence1', 'sentence2'],\n",
    "        'ax': ['premise', 'hypothesis']\n",
    "    }\n",
    "\n",
    "    glue_task_num_labels = {\n",
    "        'cola': 2,\n",
    "        'sst2': 2,\n",
    "        'mrpc': 2,\n",
    "        'qqp': 2,\n",
    "        'stsb': 1,\n",
    "        'mnli': 3,\n",
    "        'qnli': 2,\n",
    "        'rte': 2,\n",
    "        'wnli': 2,\n",
    "        'ax': 3\n",
    "    }\n",
    "\n",
    "    loader_columns = [\n",
    "        'datasets_idx',\n",
    "        'input_ids',\n",
    "        'token_type_ids',\n",
    "        'attention_mask',\n",
    "        'start_positions',\n",
    "        'end_positions',\n",
    "        'labels'\n",
    "    ]\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        model_name_or_path: str,\n",
    "        task_name: str ='mrpc',\n",
    "        max_seq_length: int = 128,\n",
    "        train_batch_size: int = 32,\n",
    "        eval_batch_size: int = 32,\n",
    "        **kwargs\n",
    "    ):\n",
    "        super().__init__()\n",
    "        self.model_name_or_path = model_name_or_path\n",
    "        self.task_name = task_name\n",
    "        self.max_seq_length = max_seq_length\n",
    "        self.train_batch_size = train_batch_size\n",
    "        self.eval_batch_size = eval_batch_size\n",
    "\n",
    "        self.text_fields = self.task_text_field_map[task_name]\n",
    "        self.num_labels = self.glue_task_num_labels[task_name]\n",
    "        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)\n",
    "\n",
    "    def setup(self, stage):\n",
    "        self.dataset = datasets.load_dataset('glue', self.task_name)\n",
    "\n",
    "        for split in self.dataset.keys():\n",
    "            self.dataset[split] = self.dataset[split].map(\n",
    "                self.convert_to_features,\n",
    "                batched=True,\n",
    "                remove_columns=['label'],\n",
    "            )\n",
    "            self.columns = [c for c in self.dataset[split].column_names if c in self.loader_columns]\n",
    "            self.dataset[split].set_format(type=\"torch\", columns=self.columns)\n",
    "\n",
    "        self.eval_splits = [x for x in self.dataset.keys() if 'validation' in x]\n",
    "\n",
    "    def prepare_data(self):\n",
    "        datasets.load_dataset('glue', self.task_name)\n",
    "        AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)\n",
    "    \n",
    "    def train_dataloader(self):\n",
    "        return DataLoader(self.dataset['train'], batch_size=self.train_batch_size)\n",
    "    \n",
    "    def val_dataloader(self):\n",
    "        if len(self.eval_splits) == 1:\n",
    "            return DataLoader(self.dataset['validation'], batch_size=self.eval_batch_size)\n",
    "        elif len(self.eval_splits) > 1:\n",
    "            return [DataLoader(self.dataset[x], batch_size=self.eval_batch_size) for x in self.eval_splits]\n",
    "\n",
    "    def test_dataloader(self):\n",
    "        if len(self.eval_splits) == 1:\n",
    "            return DataLoader(self.dataset['test'], batch_size=self.eval_batch_size)\n",
    "        elif len(self.eval_splits) > 1:\n",
    "            return [DataLoader(self.dataset[x], batch_size=self.eval_batch_size) for x in self.eval_splits]\n",
    "\n",
    "    def convert_to_features(self, example_batch, indices=None):\n",
    "\n",
    "        # Either encode single sentence or sentence pairs\n",
    "        if len(self.text_fields) > 1:\n",
    "            texts_or_text_pairs = list(zip(example_batch[self.text_fields[0]], example_batch[self.text_fields[1]]))\n",
    "        else:\n",
    "            texts_or_text_pairs = example_batch[self.text_fields[0]]\n",
    "\n",
    "        # Tokenize the text/text pairs\n",
    "        features = self.tokenizer.batch_encode_plus(\n",
    "            texts_or_text_pairs,\n",
    "            max_length=self.max_seq_length,\n",
    "            pad_to_max_length=True,\n",
    "            truncation=True\n",
    "        )\n",
    "\n",
    "        # Rename label to labels to make it easier to pass to model forward\n",
    "        features['labels'] = example_batch['label']\n",
    "\n",
    "        return features"
   ]
  },
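  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The text-pairing step in `convert_to_features` can be sketched without a tokenizer. This is a minimal illustration, not part of the datamodule: the `build_texts` helper and the sample sentences are hypothetical, and the batch is assumed to arrive as a dict of columns, which is how a batched `map` call passes examples.\n",
    "\n",
    "```python\n",
    "# Hypothetical mini-batch in column-oriented form; the sentences are made up.\n",
    "example_batch = {\n",
    "    'sentence1': ['The cat sat.', 'He left early.'],\n",
    "    'sentence2': ['A cat was sitting.', 'She arrived late.'],\n",
    "    'label': [1, 0],\n",
    "}\n",
    "\n",
    "\n",
    "def build_texts(example_batch, text_fields):\n",
    "    # Mirror the branching in convert_to_features:\n",
    "    # sentence pairs for two fields, plain sentences for one.\n",
    "    if len(text_fields) > 1:\n",
    "        return list(zip(example_batch[text_fields[0]], example_batch[text_fields[1]]))\n",
    "    return example_batch[text_fields[0]]\n",
    "\n",
    "\n",
    "pairs = build_texts(example_batch, ['sentence1', 'sentence2'])\n",
    "singles = build_texts(example_batch, ['sentence1'])\n",
    "print(pairs)\n",
    "print(singles)\n",
    "```\n",
    "\n",
    "Two-field tasks such as MRPC produce a list of tuples, while single-field tasks such as CoLA pass the sentence list through unchanged."
   ]
  },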
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "jQC3a6KuOpX3"
   },
   "source": [
    "#### You could use this datamodule with standalone PyTorch if you wanted..."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "JCMH3IAsNffF"
   },
   "outputs": [],
   "source": [
    "dm = GLUEDataModule('distilbert-base-uncased')\n",
    "dm.prepare_data()\n",
    "dm.setup('fit')\n",
    "next(iter(dm.train_dataloader()))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "l9fQ_67BO2Lj"
   },
   "source": [
    "## GLUE Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "gtn5YGKYO65B"
   },
   "outputs": [],
   "source": [
    "class GLUETransformer(pl.LightningModule):\n",
    "    def __init__(\n",
    "        self,\n",
    "        model_name_or_path: str,\n",
    "        num_labels: int,\n",
    "        learning_rate: float = 2e-5,\n",
    "        adam_epsilon: float = 1e-8,\n",
    "        warmup_steps: int = 0,\n",
    "        weight_decay: float = 0.0,\n",
    "        train_batch_size: int = 32,\n",
    "        eval_batch_size: int = 32,\n",
    "        eval_splits: Optional[list] = None,\n",
    "        **kwargs\n",
    "    ):\n",
    "        super().__init__()\n",
    "\n",
    "        self.save_hyperparameters()\n",
    "\n",
    "        self.config = AutoConfig.from_pretrained(model_name_or_path, num_labels=num_labels)\n",
    "        self.model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, config=self.config)\n",
    "        self.metric = datasets.load_metric(\n",
    "            'glue',\n",
    "            self.hparams.task_name,\n",
    "            experiment_id=datetime.now().strftime(\"%d-%m-%Y_%H-%M-%S\")\n",
    "        )\n",
    "\n",
    "    def forward(self, **inputs):\n",
    "        return self.model(**inputs)\n",
    "\n",
    "    def training_step(self, batch, batch_idx):\n",
    "        outputs = self(**batch)\n",
    "        loss = outputs[0]\n",
    "        return loss\n",
    "\n",
    "    def validation_step(self, batch, batch_idx, dataloader_idx=0):\n",
    "        outputs = self(**batch)\n",
    "        val_loss, logits = outputs[:2]\n",
    "\n",
    "        if self.hparams.num_labels >= 1:\n",
    "            preds = torch.argmax(logits, axis=1)\n",
    "        elif self.hparams.num_labels == 1:\n",
    "            preds = logits.squeeze()\n",
    "\n",
    "        labels = batch[\"labels\"]\n",
    "\n",
    "        return {'loss': val_loss, \"preds\": preds, \"labels\": labels}\n",
    "\n",
    "    def validation_epoch_end(self, outputs):\n",
    "        if self.hparams.task_name == 'mnli':\n",
    "            for i, output in enumerate(outputs):\n",
    "                # matched or mismatched\n",
    "                split = self.hparams.eval_splits[i].split('_')[-1]\n",
    "                preds = torch.cat([x['preds'] for x in output]).detach().cpu().numpy()\n",
    "                labels = torch.cat([x['labels'] for x in output]).detach().cpu().numpy()\n",
    "                loss = torch.stack([x['loss'] for x in output]).mean()\n",
    "                self.log(f'val_loss_{split}', loss, prog_bar=True)\n",
    "                split_metrics = {f\"{k}_{split}\": v for k, v in self.metric.compute(predictions=preds, references=labels).items()}\n",
    "                self.log_dict(split_metrics, prog_bar=True)\n",
    "            return loss\n",
    "\n",
    "        preds = torch.cat([x['preds'] for x in outputs]).detach().cpu().numpy()\n",
    "        labels = torch.cat([x['labels'] for x in outputs]).detach().cpu().numpy()\n",
    "        loss = torch.stack([x['loss'] for x in outputs]).mean()\n",
    "        self.log('val_loss', loss, prog_bar=True)\n",
    "        self.log_dict(self.metric.compute(predictions=preds, references=labels), prog_bar=True)\n",
    "        return loss\n",
    "\n",
    "    def setup(self, stage):\n",
    "        if stage == 'fit':\n",
    "            # Get dataloader by calling it - train_dataloader() is called after setup() by default\n",
    "            train_loader = self.train_dataloader()\n",
    "\n",
    "            # Calculate total steps\n",
    "            self.total_steps = (\n",
    "                (len(train_loader.dataset) // (self.hparams.train_batch_size * max(1, self.hparams.gpus)))\n",
    "                // self.hparams.accumulate_grad_batches\n",
    "                * float(self.hparams.max_epochs)\n",
    "            )\n",
    "\n",
    "    def configure_optimizers(self):\n",
    "        \"Prepare optimizer and schedule (linear warmup and decay)\"\n",
    "        model = self.model\n",
    "        no_decay = [\"bias\", \"LayerNorm.weight\"]\n",
    "        optimizer_grouped_parameters = [\n",
    "            {\n",
    "                \"params\": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],\n",
    "                \"weight_decay\": self.hparams.weight_decay,\n",
    "            },\n",
    "            {\n",
    "                \"params\": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],\n",
    "                \"weight_decay\": 0.0,\n",
    "            },\n",
    "        ]\n",
    "        optimizer = AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon)\n",
    "\n",
    "        scheduler = get_linear_schedule_with_warmup(\n",
    "            optimizer, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=self.total_steps\n",
    "        )\n",
    "        scheduler = {\n",
    "            'scheduler': scheduler,\n",
    "            'interval': 'step',\n",
    "            'frequency': 1\n",
    "        }\n",
    "        return [optimizer], [scheduler]\n",
    "\n",
    "    @staticmethod\n",
    "    def add_model_specific_args(parent_parser):\n",
    "        parser = ArgumentParser(parents=[parent_parser], add_help=False)\n",
    "        parser.add_argument(\"--learning_rate\", default=2e-5, type=float)\n",
    "        parser.add_argument(\"--adam_epsilon\", default=1e-8, type=float)\n",
    "        parser.add_argument(\"--warmup_steps\", default=0, type=int)\n",
    "        parser.add_argument(\"--weight_decay\", default=0.0, type=float)\n",
    "        return parser"
   ]
  },
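  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The step count computed in `setup` and the schedule it feeds can be checked by hand. The sketch below is only an illustration: both helper functions are hypothetical, the dataset size is an assumed value, and `linear_lr_factor` is an assumed approximation of a linear warmup and decay multiplier rather than the library's own code.\n",
    "\n",
    "```python\n",
    "def total_training_steps(num_examples, train_batch_size, gpus, accumulate_grad_batches, max_epochs):\n",
    "    # Same arithmetic as GLUETransformer.setup; floor division drops a trailing partial batch\n",
    "    effective_batch_size = train_batch_size * max(1, gpus)\n",
    "    batches_per_epoch = num_examples // effective_batch_size\n",
    "    return batches_per_epoch // accumulate_grad_batches * float(max_epochs)\n",
    "\n",
    "\n",
    "def linear_lr_factor(step, warmup_steps, total_steps):\n",
    "    # Assumed schedule shape: ramp from 0 to 1 during warmup, then decay linearly to 0\n",
    "    if step < warmup_steps:\n",
    "        return step / max(1, warmup_steps)\n",
    "    return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))\n",
    "\n",
    "\n",
    "# Assumed values: 3,668 training examples, batch size 32, one GPU, three epochs\n",
    "total_steps = total_training_steps(3668, 32, 1, 1, 3)\n",
    "factors = [linear_lr_factor(s, 0, total_steps) for s in (0, 171, 342)]\n",
    "print(total_steps)\n",
    "print(factors)\n",
    "```\n",
    "\n",
    "Because the floor division ignores a trailing partial batch, this estimate can be slightly lower than the number of optimizer steps actually taken."
   ]
  },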
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "ha-NdIP_xbd3"
   },
   "source": [
    "### ⚡ Quick Tip \n",
    "  - Combine arguments from your DataModule, Model, and Trainer into one for easy and robust configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "3dEHnl3RPlAR"
   },
   "outputs": [],
   "source": [
    "def parse_args(args=None):\n",
    "    parser = ArgumentParser()\n",
    "    parser = pl.Trainer.add_argparse_args(parser)\n",
    "    parser = GLUEDataModule.add_argparse_args(parser)\n",
    "    parser = GLUETransformer.add_model_specific_args(parser)\n",
    "    parser.add_argument('--seed', type=int, default=42)\n",
    "    return parser.parse_args(args)\n",
    "\n",
    "\n",
    "def main(args):\n",
    "    pl.seed_everything(args.seed)\n",
    "    dm = GLUEDataModule.from_argparse_args(args)\n",
    "    dm.prepare_data()\n",
    "    dm.setup('fit')\n",
    "    model = GLUETransformer(num_labels=dm.num_labels, eval_splits=dm.eval_splits, **vars(args))\n",
    "    trainer = pl.Trainer.from_argparse_args(args)\n",
    "    return dm, model, trainer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "PkuLaeec3sJ-"
   },
   "source": [
    "# Training"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "QSpueK5UPsN7"
   },
   "source": [
    "## CoLA\n",
    "\n",
    "See an interactive view of the CoLA dataset in [NLP Viewer](https://huggingface.co/nlp/viewer/?dataset=glue&config=cola)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "NJnFmtpnPu0Y"
   },
   "outputs": [],
   "source": [
    "mocked_args = \"\"\"\n",
    "    --model_name_or_path albert-base-v2\n",
    "    --task_name cola\n",
    "    --max_epochs 3\n",
    "    --gpus 1\"\"\".split()\n",
    "\n",
    "args = parse_args(mocked_args)\n",
    "dm, model, trainer = main(args)\n",
    "trainer.fit(model, dm)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "_MrNsTnqdz4z"
   },
   "source": [
    "## MRPC\n",
    "\n",
    "See an interactive view of the MRPC dataset in [NLP Viewer](https://huggingface.co/nlp/viewer/?dataset=glue&config=mrpc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "LBwRxg9Cb3d-"
   },
   "outputs": [],
   "source": [
    "mocked_args = \"\"\"\n",
    "    --model_name_or_path distilbert-base-cased\n",
    "    --task_name mrpc\n",
    "    --max_epochs 3\n",
    "    --gpus 1\"\"\".split()\n",
    "\n",
    "args = parse_args(mocked_args)\n",
    "dm, model, trainer = main(args)\n",
    "trainer.fit(model, dm)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "iZhbn0HzfdCu"
   },
   "source": [
    "## MNLI\n",
    "\n",
    " - The MNLI dataset is huge, so we aren't going to bother trying to train it here.\n",
    "\n",
    " - Let's just make sure our multi-dataloader logic is right by skipping over training and going straight to validation.\n",
    "\n",
    "See an interactive view of the MRPC dataset in [NLP Viewer](https://huggingface.co/nlp/viewer/?dataset=glue&config=mnli)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "AvsZMOggfcWW"
   },
   "outputs": [],
   "source": [
    "mocked_args = \"\"\"\n",
    "    --model_name_or_path distilbert-base-uncased\n",
    "    --task_name mnli\n",
    "    --max_epochs 1\n",
    "    --gpus 1\n",
    "    --limit_train_batches 10\n",
    "    --progress_bar_refresh_rate 20\"\"\".split()\n",
    "\n",
    "args = parse_args(mocked_args)\n",
    "dm, model, trainer = main(args)\n",
    "trainer.fit(model, dm)"
   ]
  },
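  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With two validation dataloaders, `validation_epoch_end` suffixes each metric name with the split it came from. The naming logic can be sketched on its own, assuming the split names `validation_matched` and `validation_mismatched`; the metric values here are made up for illustration.\n",
    "\n",
    "```python\n",
    "eval_splits = ['validation_matched', 'validation_mismatched']\n",
    "\n",
    "# Hypothetical per-split results standing in for self.metric.compute(...)\n",
    "per_split_metrics = [{'accuracy': 0.80}, {'accuracy': 0.75}]\n",
    "\n",
    "logged = {}\n",
    "for name, metrics in zip(eval_splits, per_split_metrics):\n",
    "    # 'validation_matched' -> 'matched'\n",
    "    split = name.split('_')[-1]\n",
    "    logged.update({f'{k}_{split}': v for k, v in metrics.items()})\n",
    "\n",
    "print(logged)\n",
    "```\n",
    "\n",
    "Keeping the suffix in the key lets both splits appear side by side in the progress bar without overwriting each other."
   ]
  },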
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<code style=\"color:#792ee5;\">\n",
    "    <h1> <strong> Congratulations - Time to Join the Community! </strong>  </h1>\n",
    "</code>\n",
    "\n",
    "Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the Lightning movement, you can do so in the following ways!\n",
    "\n",
    "### Star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) on GitHub\n",
    "The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool tools we're building.\n",
    "\n",
    "* Please, star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning)\n",
    "\n",
    "### Join our [Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)!\n",
    "The best way to keep up to date on the latest advancements is to join our community! Make sure to introduce yourself and share your interests in `#general` channel\n",
    "\n",
    "### Interested by SOTA AI models ! Check out [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts)\n",
    "Bolts has a collection of state-of-the-art models, all implemented in [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) and can be easily integrated within your own projects.\n",
    "\n",
    "* Please, star [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts)\n",
    "\n",
    "### Contributions !\n",
    "The best way to contribute to our community is to become a code contributor! At any time you can go to [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) or [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts) GitHub Issues page and filter for \"good first issue\". \n",
    "\n",
    "* [Lightning good first issue](https://github.com/PyTorchLightning/pytorch-lightning/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n",
    "* [Bolt good first issue](https://github.com/PyTorchLightning/pytorch-lightning-bolts/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n",
    "* You can also contribute your own notebooks with useful examples !\n",
    "\n",
    "### Great thanks from the entire Pytorch Lightning Team for your interest !\n",
    "\n",
    "<img src=\"https://github.com/PyTorchLightning/pytorch-lightning/blob/master/docs/source/_images/logos/lightning_logo-name.png?raw=true\" width=\"800\" height=\"200\" />"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "name": "04-transformers-text-classification.ipynb",
   "provenance": [],
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
